﻿using System;
using System.Collections.Generic;
using System.Linq;
using HAPI;
using System.Text.RegularExpressions;
using HuginApiAddonsCS.Policy;

namespace HuginApiAddonsCS.Extensions
{
    public static class NodeExtensions
    {
        /// <summary>
        /// Shared random source for <see cref="SetRandomState"/>. A single instance is reused because
        /// creating a new <see cref="Random"/> per call can yield identical seeds (and therefore identical
        /// samples) when calls happen in quick succession.
        /// </summary>
        private static readonly Random s_random = new Random();

        /// <summary>
        /// Guards <see cref="s_random"/>; <see cref="Random"/> instance methods are not thread-safe.
        /// </summary>
        private static readonly object s_randomLock = new object();

        /// <summary>
        /// Returns the leading run of non-digit characters of the node's name
        /// (e.g. "M" for a node named "M12"). Returns an empty string if the name starts with a digit.
        /// </summary>
        /// <param name="node">node</param>
        /// <returns>prefix</returns>
        public static string GetNodePrefix(this Node node)
        {
            return Regex.Match(node.GetName(), @"^[^\d]+").Value;
        }

        /// <summary>
        /// Extracts the node time step: the first run of digits in the node's name, parsed as an integer
        /// (e.g. 12 for a node named "M12").
        /// </summary>
        /// <param name="node">node</param>
        /// <returns>time step value</returns>
        /// <exception cref="FormatException">The node name contains no digits.</exception>
        /// <exception cref="OverflowException">The digit run does not fit in an <see cref="int"/>.</exception>
        public static int GetNodeTimeStep(this Node node)
        {
            return ParseFirstInteger(node.GetName());
        }

        /// <summary>
        /// Gets the label of the state with the highest expected utility.
        /// Ties are resolved in favour of the lowest state index.
        /// </summary>
        /// <param name="node">node; expected to have at least one state</param>
        /// <returns>label of the state with the highest expected utility</returns>
        public static string GetStateWithHighestReward(this DiscreteNode node)
        {
            return node.GetStateLabel(GetIndexOfHighestExpectedUtility(node));
        }

        /// <summary>
        /// Gets the node for the next time step by delegating to the home domain's GetNextStepNode.
        /// </summary>
        /// <param name="node">query node</param>
        /// <returns>next node; null if no node exists (as reported by GetNextStepNode)</returns>
        public static Node GetNextTimeStepNode(this Node node)
        {
            return node.GetHomeDomain().GetNextStepNode(node);
        }

        /// <summary>
        /// Resets a model node's table based on the states of its parent nodes.
        /// <list type="bullet">
        /// <item>If <paramref name="previousModelNode"/> is null, every table entry is set to 1.</item>
        /// <item>Otherwise, if <paramref name="observationNode"/> is non-null, the table becomes deterministic.
        /// An entry is 1 when the node's own state label equals "model,action,observation", built from the
        /// labels of parents 0, 1 and 2; all other entries are 0. A row (one parent configuration) with no
        /// matching own state is set to all 1s, so the net still compiles.</item>
        /// <item>If both arguments are non-null but <paramref name="observationNode"/> is null, the table is
        /// left unchanged.</item>
        /// </list>
        /// </summary>
        /// <remarks>
        /// The code assumes the node has exactly three discrete parents, in the order model/action/observation.
        /// It also assumes the node's own state is the last (fastest-varying) index of the table configuration,
        /// so each group of consecutive entries forms one row.
        /// </remarks>
        /// <param name="node">node to reset</param>
        /// <param name="previousModelNode">previous model node; only tested for null</param>
        /// <param name="observationNode">observation node; only tested for null</param>
        public static void ResetModelNodeTableFromParents(this DiscreteNode node, DiscreteNode previousModelNode, DiscreteNode observationNode)
        {
            if (previousModelNode == null)
            {
                // No previous model to compare with: leave every combination possible.
                var uniformTable = node.GetTable();
                float[] ones = uniformTable.GetData();
                for (int i = 0; i < ones.Length; i++)
                {
                    ones[i] = 1;
                }

                uniformTable.SetData(ones);
                return;
            }

            if (observationNode == null)
            {
                return;
            }

            List<Node> parents = node.GetParents();
            DiscreteNode modelParent = (DiscreteNode)parents[0];
            DiscreteNode actionParent = (DiscreteNode)parents[1];
            DiscreteNode observationParent = (DiscreteNode)parents[2];

            var table = node.GetTable();
            float[] data = table.GetData();
            int stateCount = (int)node.GetNumberOfStates();

            // Tracks whether any own state matched within the current row; reset when own state index is 0.
            bool rowHasMatch = false;

            for (int i = 0; i < data.Length; i++)
            {
                uint[] configuration = new uint[4];
                table.GetConfiguration(ref configuration, (uint)i);

                if (configuration[3] == 0)
                {
                    rowHasMatch = false;
                }

                string expectedState = modelParent.GetStateLabel(configuration[0]) + ","
                    + actionParent.GetStateLabel(configuration[1]) + ","
                    + observationParent.GetStateLabel(configuration[2]);

                bool isMatch = expectedState == node.GetStateLabel(configuration[3]);
                data[i] = isMatch ? 1f : 0f;
                rowHasMatch |= isMatch;

                if (configuration[3] == stateCount - 1 && !rowHasMatch)
                {
                    // Impossible history: mark the whole row as 1s rather than all 0s,
                    // which would otherwise prevent the net from compiling.
                    for (int j = i - (stateCount - 1); j <= i; j++)
                    {
                        data[j] = 1;
                    }
                }
            }

            table.SetData(data);
        }

        /// <summary>
        /// Builds a deterministic action table by querying a policy tree with the history encoded in the
        /// parent's state label.
        /// </summary>
        /// <remarks>
        /// The parent's (first parent's) state label is split on ','. The first element identifies the
        /// model: its first run of digits is used as an index into <paramref name="trees"/>. The remaining
        /// elements form the history passed to GetActionForHistory. An entry is 1 when the node's own state
        /// label equals the action returned by that tree, otherwise 0.
        /// </remarks>
        /// <param name="node">node whose table is rewritten; first parent is the history/model node</param>
        /// <param name="trees">policy trees, indexed by model number</param>
        public static void SetJActionNodeTablesFromTree(this DiscreteNode node, List<PolicyNode> trees)
        {
            DiscreteNode parent = (DiscreteNode)node.GetParents().First();

            var table = node.GetTable();
            float[] data = table.GetData();
            for (int i = 0; i < data.Length; i++)
            {
                // configuration[0] is the parent state, configuration[1] is this node's own state.
                uint[] configuration = new uint[2];
                table.GetConfiguration(ref configuration, (uint)i);

                List<string> history = parent.GetStateLabel(configuration[0]).Split(',').ToList();
                int modelNum = ParseFirstInteger(history[0]);
                history.RemoveAt(0);
                string action = node.GetStateLabel(configuration[1]);

                // Enter the history into the root of the model's tree; mark the entry if the actions agree.
                data[i] = action == trees[modelNum].GetActionForHistory(history) ? 1f : 0f;
            }

            table.SetData(data);
        }

        /// <summary>
        /// Samples a state from the node's current beliefs and selects it.
        /// States with zero belief are never chosen. If rounding makes the beliefs sum to slightly less
        /// than the sample, the last state with positive belief is chosen instead.
        /// </summary>
        /// <param name="node">node</param>
        /// <returns>index of the chosen state for later use; 0 (with no state selected) if no state has positive belief</returns>
        public static uint SetRandomState(this DiscreteNode node)
        {
            double sample;
            lock (s_randomLock)
            {
                sample = s_random.NextDouble();
            }

            double cumulative = 0.0;
            uint? lastPossibleState = null;

            for (uint i = 0; i < node.GetNumberOfStates(); i++)
            {
                double belief = node.GetBelief(i);
                if (belief <= 0)
                {
                    continue;
                }

                lastPossibleState = i;
                cumulative += belief;

                // Inverse-CDF sampling: pick the first state whose cumulative belief exceeds the sample.
                if (sample < cumulative)
                {
                    node.SelectState(i);
                    return i;
                }
            }

            if (lastPossibleState.HasValue)
            {
                node.SelectState(lastPossibleState.Value);
                return lastPossibleState.Value;
            }

            return 0;
        }

        /// <summary>
        /// Selects the state with the highest expected utility (typically on a decision node).
        /// Ties are resolved in favour of the lowest state index.
        /// </summary>
        /// <param name="node">node; expected to have at least one state</param>
        /// <returns>index of the selected state</returns>
        public static uint SetStateWithHighestReward(this DiscreteNode node)
        {
            uint bestState = GetIndexOfHighestExpectedUtility(node);
            node.SelectState(bestState);
            return bestState;
        }

        /// <summary>
        /// Returns the index of the state with the highest expected utility; the lowest index wins ties.
        /// Returns 0 if the node has no states.
        /// </summary>
        private static uint GetIndexOfHighestExpectedUtility(DiscreteNode node)
        {
            float bestUtility = 0.0f;
            uint bestState = 0;

            for (uint i = 0; i < node.GetNumberOfStates(); i++)
            {
                float utility = node.GetExpectedUtility(i);
                if (i == 0 || utility > bestUtility)
                {
                    bestUtility = utility;
                    bestState = i;
                }
            }

            return bestState;
        }

        /// <summary>
        /// Parses the first run of digits in <paramref name="text"/> as an integer.
        /// Throws <see cref="FormatException"/> if there are no digits.
        /// </summary>
        private static int ParseFirstInteger(string text)
        {
            return Convert.ToInt32(Regex.Match(text, @"\d+").Value);
        }
    }
}
